# encoding:utf-8
from pyspark.sql import SparkSession
from utils.model_util import model_type
from utils.database_util import read_mysql, update_record, \
    query_by_model_id, get_task_data_json, update_task, get_progress_id
from utils.spark_util import read_gp_spark, save_gp_spark
from utils.format_util import cast_float, py_to_java
from utils.mlmodel_util import semantic_filling
from algo.mlmodel_spark import decision_tree_spark, linear_regression_spark, pca_spark, kmeans_spark, auto_regression_spark
import json
import logging
import argparse
import time
import copy
from common.config.config import *

# os.environ["PYSPARK_PYTHON"]="/usr/bin/python3"
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")


def update_model_instance_record(model_id, driver_mem, executer_mem, executer_cores):
    """Mirror the model row and spark resource settings into ``record``.

    Mutates the module-level ``record`` dict in place: first the spark
    resource settings passed in, then every column of the model row
    fetched via query_by_model_id(). Persisting the assembled record is
    currently disabled (save call commented out at the bottom).

    :param model_id: primary key of the model row to mirror.
    :param driver_mem: spark driver memory string (e.g. "1g").
    :param executer_mem: spark executor memory string (e.g. "2g").
    :param executer_cores: number of executor cores (int).
    """
    # --------------------------------------update model instance -----------------------------------
    # (The original "0" placeholder assignments were dead stores —
    # immediately overwritten by the real values — and were removed.)
    record["driver_mem"] = driver_mem
    record["executor_mem"] = executer_mem
    record["executor_cores"] = executer_cores

    series = query_by_model_id(model_id)
    record["id"] = model_id
    # Copy the remaining columns verbatim from the fetched row.
    for field in ("user_id", "project_id", "algorithm", "model_type",
                  "name", "model_desc", "param", "source_table",
                  "gmt_modifier", "gmt_creator", "progress_id"):
        record[field] = series[field]

    # save_model_instance(record, "model_instance")


def return_output(output, spark):
    """Log the final result payload and shut down the Spark session.

    Mutates ``output`` in place: when the "error" list is empty, the
    status code is flipped to 0 (success) before serialization.

    :param output: result dict with at least an "error" list entry.
    :param spark: active SparkSession, or None when no session exists.
    """
    logging.info("finished with outputs:")
    # Truthiness test replaces len(...) == 0; empty error list == success.
    if not output["error"]:
        output["status"] = 0
    output_json = json.dumps(output)
    logging.info(output_json)
    # "is not None" replaces the non-idiomatic "not spark == None".
    if spark is not None:
        spark.stop()


def kill_cur_pid():
    """Terminate the current process by shelling out ``kill <own pid>``.

    Best-effort: any failure is logged and swallowed. Note that the
    SIGTERM is delivered asynchronously, so code after a call site may
    still execute briefly before the process dies.
    """
    cur_pid = os.getpid()
    cmd = "kill " + str(cur_pid)
    try:
        os.system(cmd)
        # Was logging.info(cur_pid, "killed"): that passed the pid as the
        # format string with a stray argument, breaking log formatting.
        logging.info("%s killed", cur_pid)
    except Exception as e:
        logging.info(e)


def spark_entry(sql, output, para_list):
    """Build a SparkSession, read the source data, and dispatch to the
    algorithm selected by the module-level ``algo``/``execution`` globals.

    :param sql: SELECT statement for the source data (wrapped as subquery).
    :param output: shared result dict the algorithms write errors/metrics to.
    :param para_list: parsed model parameter dict.
    :returns: tuple ``(midtime, output, df_pred, model)`` — ``df_pred`` is
        populated for "predict" runs, ``model`` for training runs; the
        unused slot (or both, for an unknown algo) is None.
    """
    # TODO assign spark memory/cores based on data size instead of fixed values
    spark = SparkSession \
        .Builder() \
        .appName("ml_model") \
        .config('spark.executor.memory', "8g") \
        .config("spark.executor.cores", 4) \
        .config("spark.driver.memory", "1g").getOrCreate()

    # NOTE(review): this session is never returned to the caller, so
    # return_output(output, spark) in __main__ cannot stop it — confirm
    # whether leaking the session is intended.
    spark.sparkContext.setLogLevel("Error")
    sc = spark.sparkContext
    midtime = time.time()

    # -------------------------------------minIO conf--------------------------------------------
    sc._jsc.hadoopConfiguration().set("fs.s3a.access.key", MINIO_ACCESS_KEY)
    sc._jsc.hadoopConfiguration().set("fs.s3a.secret.key", MINIO_SECRET_KEY)
    sc._jsc.hadoopConfiguration().set("fs.s3a.endpoint", MINIO_ADDRESS)
    sc._jsc.hadoopConfiguration().set("fs.s3a.path.style.access", "true")
    sc._jsc.hadoopConfiguration().set("fs.s3a.connection.ssl.enabled", "false")
    sc._jsc.hadoopConfiguration().set("fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")

    # ---------------------------------Algorithm Task Entry---------------------------------------
    sql = "({}) as tmp".format(sql)
    dataframe = read_gp_spark(spark, sql)

    # Pre-initialize both result slots so an unrecognized algo no longer
    # raises NameError at the return statements below.
    df_pred = None
    model = None

    if algo == "DTC":
        if execution == "predict":
            df_pred = decision_tree_spark.predict(sc, dataframe, para_list, output)
        else:
            model = decision_tree_spark.train(sc, dataframe, para_list, record, output)
    elif algo == "LR":
        if execution == "predict":
            df_pred = linear_regression_spark.predict(sc, dataframe, para_list, output)
        else:
            model, y_pred = linear_regression_spark.train(sc, dataframe, para_list, record, output)
    elif algo == "KM":
        if execution == "predict":
            df_pred = kmeans_spark.predict(sc, dataframe, para_list, output)
        else:
            model = kmeans_spark.train(sc, dataframe, para_list, record, output)
    elif algo == "PCA":
        if execution == "predict":
            df_pred = pca_spark.predict(sc, spark, dataframe, para_list, output)
        else:
            model = pca_spark.train(sc, spark, dataframe, para_list, record, output)
    elif algo == "AUTOREG":
        if execution == "predict":
            df_pred = None
        else:
            model = auto_regression_spark.train(sc, spark, dataframe, para_list, record, output)
    else:
        output = {}
        # A list (was a set literal) keeps the payload JSON-serializable
        # when return_output() calls json.dumps on it.
        output["error"] = ["algo name not found"]

    if execution == "predict":
        return midtime, output, df_pred, None
    else:
        return midtime, output, None, model


def record_result():
    """Persist the run outcome.

    For "train" runs: write the module-level ``record`` back to the model
    table and refresh the model-instance record. For "predict" runs:
    extend the task's output schema with any new columns produced by
    ``df_pred`` and save the updated task metadata. Reads module-level
    globals (execution, record, model_id, df_pred, para_list, target,
    task_id, instance_id, driver_mem, executer_mem, executer_cores).
    """
    if execution == "train":
        logging.info("RECORD")
        logging.info(record)
        update_record(model_id, record, "model")
        # TODO
        update_model_instance_record(model_id, driver_mem, executer_mem, executer_cores)

    if execution == "predict":
        logging.info("update task:" + str(task_id) + ", with instance:" + str(instance_id))
        data_json = get_task_data_json(task_id)

        logging.info("OLD DATAJSON GOT")

        # Output schema starts as a deep copy of the input schema.
        data_json["output"] = copy.deepcopy(data_json["input"])

        existing_cols = data_json["output"][0]["tableCols"]
        # Columns produced by prediction that the input schema lacks.
        added_cols = [c for c in df_pred.columns if c not in existing_cols]

        data_json["output"][0]["tableCols"] = existing_cols + added_cols

        semantic_filling(added_cols, data_json, para_list)

        logging.info("SEMANTIC FILLED")

        data_json["output"][0]["tableName"] = target

        logging.info(data_json)

        update_task(data_json, task_id)
        # saveMetaForMysql(output, instance_id)


if __name__ == "__main__":
    # CLI entry point. Dispatches on --exe: "kill", "feature" (stub), or
    # the train/predict pipeline in the final else branch.
    parser = argparse.ArgumentParser()
    parser.add_argument("--id", type=int, required=True, default=0)
    parser.add_argument("--exe", type=str, required=True)
    parser.add_argument("--source", type=str)
    parser.add_argument("--target", type=str, default="pipeline._solid_mlmodel")
    parser.add_argument("--feature_col", type=str)
    parser.add_argument("--label_col", type=str, default="")
    parser.add_argument("--task_id", type=int)
    parser.add_argument("--instance_id", type=int)
    # Format: "<executor_cores>,<executor_mem>,<driver_mem>"
    parser.add_argument("--spark_config", type=str, default="4,2g,1g")

    args = parser.parse_args()

    spark_config = args.spark_config.split(',')
    driver_mem = spark_config[2]
    executer_cores = int(spark_config[0])
    executer_mem = spark_config[1]

    model_id = args.id
    target = args.target

    execution = args.exe
    logging.info("it's python time")
    if execution == "kill":
        # Kill the worker process recorded in the DB, then terminate ourselves.
        pid = get_progress_id(model_id)
        cmd = "kill " + str(pid)
        try:
            os.system(cmd)
        except Exception as e:
            logging.info(e)
        kill_cur_pid()

    elif execution == "feature":
        # Placeholder branch: feature handling is not implemented yet.
        try:
            print(1)
        except Exception as e:
            logging.error("Throw exception: {}".format(e))

    else:
        # "train" / "predict" path. These names are module-level globals
        # read by spark_entry() and record_result().
        spark = None
        record = {}
        output = {}
        instance_id = -1
        task_id = -1
        try:
            logging.info("Algorithm Begin")
            meta = read_mysql("select param, source_table from model where id = '{}'".format(model_id))
            logging.info(meta)
            para_list, source_table = meta.values[0]
            # NOTE(review): eval() executes arbitrary code from the DB
            # "param" column — prefer json.loads / ast.literal_eval.
            para_list = eval(para_list)
            para_list["source"] = source_table
            para_list["model_id"] = model_id
            para_list["execution"] = execution
            logging.info("para_list")
            logging.info(para_list)
            if execution == "predict":
                source = args.source
                # Default to the "dataset" schema when none is given.
                if "." not in source:
                    source = "dataset." + source
                para_list["source"] = source
                print(args.feature_col)
                # para_list["feature_col"] = args.feature_col.split(",")
                print(para_list["feature_col"])
                if not args.label_col == "":
                    para_list["label_col"] = args.label_col

            # Models persist under s3a://ml-model/<user>/<type>/<algo>_<id>
            model_saved_path = "s3a://ml-model/{}/{}/{}_{}".format(para_list["usr_id"], model_type(para_list["algo"]),
                                                                   para_list["algo"], model_id)

            # Pre-seed the payload; status stays 500 until return_output()
            # flips it to 0 when no errors accumulated.
            output["status"] = 500
            output["error"] = []
            output["message"] = []
            output["result"] = {}
            output["result"]["input_params"] = para_list
            output["result"]["output_params"] = {}
            output["result"]["output_params"]["model_saved_path"] = model_saved_path

            logging.info(para_list)
            algo = para_list["algo"].upper()
            algo_type = model_type(algo)

            if para_list["execution"] == "predict":
                sql = "select sparktime from model where id = '{}'".format(model_id)
                meta = read_mysql(sql)
                logging.info(meta)
                spark_time = meta.values[-1][0]
                logging.info(spark_time)
                instance_id = args.instance_id
                task_id = args.task_id

            if para_list["execution"] == "train":
                record["model_saved_path"] = model_saved_path
                model_saved_path += "_spark"
                record["model_saved_path"] = model_saved_path
                record["status"] = "RUNNING"
                # update_record(model_id, record, "model")

            # Supervised training selects only the needed columns; every
            # other case reads the full table.
            if execution == "train" and (algo_type == "classification" or algo_type == "regression"):
                columns = set(para_list["feature_col"] + ["_record_id_"] + [para_list["label_col"]])
                cols = ",".join(columns)
            else:
                cols = "*"
            sql = "select {} from {}".format(cols, para_list["source"])

            starttime = time.time()

            try:
                midtime, output, df_pred, model = spark_entry(sql, output, para_list)

            except Exception as e:
                error_message = 'Training error: {}'.format(e)
                logging.error(error_message)
                record["status"] = "FAILED"
                other_info = {}
                other_info["error"] = error_message
                # Quote normalization so the Java side can parse the string.
                other_info = py_to_java(str(other_info)).replace('"', "'").replace("'", '"').replace(" '0'", '(0)')
                record["other_info"] = other_info
                record_result()
                # NOTE(review): kill_cur_pid() delivers SIGTERM
                # asynchronously, so execution can fall through to the
                # SAVE section below before the process actually dies.
                kill_cur_pid()
            # ------------------------------------------SAVE--------------------------------------------
            if bool(output["error"]):
                record["status"] = "FAILED"
            else:
                if execution == "train":
                    y_pred_table_name = "ml_model.output_" + str(model_id)
                    logging.info("OUTPUT_TABLE SAVING")
                    print("SPARK DOES NOT SAVE THE OUTPUT")
                    # y_pred.show(5)
                    # save_gp_spark(y_pred, y_pred_table_name)
                    logging.info("OUTPUT_TABLE SAVED")
                    endtime = time.time()
                    algotime = endtime - midtime
                    sparktime = midtime - starttime
                    record["algotime"] = algotime
                    record["sparktime"] = sparktime
                    # 0.2s constant — presumably untimed startup overhead;
                    # TODO confirm.
                    record["runtime"] = algotime + sparktime + 0.2
                    # presumably record["other_info"] was populated as a dict
                    # by the train() call inside spark_entry — verify.
                    record["other_info"]["message"] = str(output["message"])
                    record["other_info"]["out_table"] = y_pred_table_name
                    record = cast_float(record)
                    other_info = py_to_java(str(record["other_info"]))
                    # TODO
                    record["other_info"] = other_info
                    logging.info("SAVING MODEL")
                    model.write().overwrite().save(model_saved_path)
                    logging.info("MODEL SAVED")

                elif execution == "predict":
                    # NOTE(review): timestamp is computed but never used.
                    timestamp = int(time.time())
                    output["result"]["output_params"]["target"] = target
                    logging.info("saving df_pred to " + target)
                    save_gp_spark(df_pred, target)
                record["status"] = "FINISHED"
            # NOTE(review): spark is still None here — the session created
            # inside spark_entry() is never handed back, so return_output()
            # does not stop it.
            return_output(output, spark)

        except Exception as e:
            error_message = 'Throw exception: {}'.format(e)
            logging.error(error_message)
            record["status"] = "FAILED"
            other_info = {}
            other_info["error"] = error_message
            other_info = py_to_java(str(other_info)).replace('"', "'").replace("'", '"').replace(" '0'", '(0)')
            record["other_info"] = other_info
        # Runs on both the success and failure paths of the try above.
        record_result()
